# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
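# example usage (flags are defined in parse_args below; the script name is whatever
# this file is saved as):
#   python <this_script>.py --env-id CartPole-v1 --mylog --autoOpenUrl --logLevel INFO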
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
import envs
from vizdoom import gym_wrapper
import logging
from bdtime import tt


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
                        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
                        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
                        help="if toggled, cuda will be enabled by default")

    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
                        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
                        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="weather to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="CartPole-v1",
                        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=500000,
                        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
                        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=10000,
                        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=500,
                        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="the batch size of sample from the reply memory")
    parser.add_argument("--start-e", type=float, default=1,
                        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.05,
                        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.5,
                        help="the fraction of `total-timesteps` it takes from start-e to go end-e")
    parser.add_argument("--learning-starts", type=int, default=10000,
                        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=10,
                        help="the frequency of training")

    # region custom arguments
    parser.add_argument("--mylog", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="if toggled, program will use tensorboard to log msg, and open the h5 url.")

    parser.add_argument("--autoOpenUrl", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
                        help="if toggled, program will auto open tensorboard url.")
    parser.add_argument("--logLevel", type=str, default="DEBUG",
                        help="the name of this experiment")
    # endregion

    args = parser.parse_args()

    # log to the console only; nothing is written to a log file
    from bdtime import log_config_dc
    logLevel = args.logLevel
    if logLevel:
        print('logLevel --- ', logLevel)
        logLevel = logLevel.upper()
        level_types = list(logging._nameToLevel.keys())
        assert logLevel in level_types, f'invalid logLevel [{logLevel}]! valid values: {level_types}'
        log_config_dc.usual.update({"level": getattr(logging, logLevel)})

    logging.basicConfig(**log_config_dc.usual)

    envs.init()
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
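    # return a zero-argument thunk so gym.vector.SyncVectorEnv can construct (and
    # seed) each environment itself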
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
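        # a small MLP: flattened observation -> 120 -> 84 -> one Q-value per discrete action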
        self.network = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, env.single_action_space.n),
        )

    def forward(self, x):
        return self.network(x)


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)
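
# worked example (using the defaults above): start_e=1.0, end_e=0.05 and
# duration = 0.5 * 500_000 = 250_000 gives epsilon = 1.0 at t=0, ~0.525 at
# t=125_000, and a clamped 0.05 for every t >= 250_000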


if __name__ == "__main__":
    args = parse_args()

    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

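        # sync_tensorboard=True mirrors TensorBoard scalars to W&B, and
        # monitor_gym=True uploads videos recorded by the RecordVideo wrapper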
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    logdir = f"runs/{run_name}"

    # --- automatically launch TensorBoard for this run
    if args.mylog:
        port = 8123
        mylog_cmd = f"tensorboard --logdir={logdir} --port {port}"
        from shells import kill_port
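        # `shells.kill_port` is assumed to be a project-local helper that frees the
        # port in case a previous TensorBoard instance is still bound to it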

        kill_port(port)
        tt.sleep(0.2)
        print('\n\n ====== finished freeing the old TensorBoard port ======\n\n')

        mylog_url = f"http://localhost:{port}/"
        print(f'*** mylog_cmd: {mylog_cmd}')
        print(f'*** mylog_url: {mylog_url}')

        # run TensorBoard in the background
        import subprocess
        import platform

        subprocess.Popen(mylog_cmd, shell=True)
        if args.autoOpenUrl:
            if platform.system() == "Windows":
                start_cmd = 'start'
            elif platform.system() == "Darwin":
                start_cmd = 'open'
            else:  # on Linux there is no `open`; use xdg-open
                start_cmd = 'xdg-open'
            os.system(f'{start_cmd} {mylog_url}')

    print(f'*** tensorboard logdir: {logdir}')

    writer = SummaryWriter(logdir)
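    # record all hyperparameters as a markdown table (shown under TensorBoard's TEXT tab)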
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
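    # note: the next line rebinds the name `envs`, shadowing the project-level `envs`
    # module imported at the top (envs.init() has already run inside parse_args)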
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    q_network = QNetwork(envs).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())

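    # SB3 replay buffer; handle_timeout_termination=True treats TimeLimit truncations
    # as non-terminal when bootstrapping, so they do not cut off the TD target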
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()

    from tqdm import tqdm

    tqdm_i = tqdm(total=args.total_timesteps)
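    # re-running __init__ is assumed to reset the bdtime timer so tt.now() measures
    # time elapsed since training started (`tt` is imported from bdtime at the top)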
    tt.__init__()

    # TRY NOT TO MODIFY: start the game
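    # the old gym API is assumed throughout: reset() returns obs only and
    # step() returns the (obs, reward, done, info) 4-tuple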
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps,
                                  global_step)
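        # epsilon-greedy: with probability epsilon sample a uniformly random action,
        # otherwise act greedily w.r.t. the online Q-network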
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = q_network(torch.Tensor(obs).to(device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                # print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                # tqdm_i.desc = f"SPS: {int(global_step / tt.now())} | "
                episodic_return = info['episode']['r']
                tqdm_i.desc = f"global_step={global_step}, episodic_return={episodic_return}"
                # tqdm_i.desc = "global_step={:>10d}, episodic_return={:>6.2f}".format(global_step, episodic_return)
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
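        # the vector env auto-resets finished sub-envs, so next_obs already holds the
        # first observation of the new episode; the true final observation is recovered
        # from infos[idx]["terminal_observation"] before storing the transition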
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts and global_step % args.train_frequency == 0:
            data = rb.sample(args.batch_size)
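            # one-step TD target: r + gamma * max_a' Q_target(s', a'), with the
            # bootstrap term zeroed out at terminal states via (1 - dones)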
            with torch.no_grad():
                target_max, _ = target_network(data.next_observations).max(dim=1)
                td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
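            # Q(s, a) of the online network for the actions actually taken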
            old_val = q_network(data.observations).gather(1, data.actions).squeeze()
            loss = F.mse_loss(td_target, old_val)

            if global_step % 100 == 0:
                writer.add_scalar("losses/td_loss", loss, global_step)
                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                # print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # optimize the model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update the target network
            if global_step % args.target_network_frequency == 0:
                target_network.load_state_dict(q_network.state_dict())
        tqdm_i.update(1)

    envs.close()
    writer.close()
